package hex.generic;

import hex.Model;
import hex.genmodel.utils.DistributionFamily;
import hex.tree.SharedTree;
import hex.tree.SharedTreeModel;
import hex.tree.drf.DRF;
import hex.tree.drf.DRFModel;
import hex.tree.gbm.GBM;
import hex.tree.gbm.GBMModel;
import org.junit.BeforeClass;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import water.Key;
import water.Scope;
import water.TestUtil;
import water.fvec.Frame;

import java.io.File;
import java.io.FileOutputStream;
import java.util.Arrays;

import static hex.genmodel.utils.DistributionFamily.AUTO;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;

/**
 * Round-trip test for tree-model MOJOs imported as {@link GenericModel}s.
 * <p>
 * For each parameter set returned by {@link #data()}, a GBM or DRF model is trained on the airlines
 * data, exported as a MOJO to a temporary file and then deleted. The MOJO is imported back through
 * {@link Generic} using the original model's key. The test asserts that the imported model generates
 * the same POJO source (after stripping the first {@code AUTOGENERATED...} line remainder) and
 * exactly the same predictions on the test frame as the original model.
 */
@RunWith(Parameterized.class)
public class GenericMojoConversionTest extends TestUtil {

    @BeforeClass
    public static void stall() {
        stall_till_cloudsize(1);
    }

    /**
     * Parameter sets under test: GBM and DRF models on {@code IsDepDelayed} with {@code AUTO}
     * distribution (each with and without {@code _balance_classes}), plus GBM models on
     * {@code Distance} with gaussian and tweedie distributions.
     *
     * @return one tree-model parameter object per test run
     */
    @Parameterized.Parameters
    public static Iterable<?> data() {
        GBMModel.GBMParameters parms1 = new GBMModel.GBMParameters();
        parms1._distribution = AUTO;
        parms1._response_column = "IsDepDelayed";

        GBMModel.GBMParameters parms2 = new GBMModel.GBMParameters();
        parms2._distribution = AUTO;
        parms2._response_column = "IsDepDelayed";
        parms2._balance_classes = true;

        DRFModel.DRFParameters parms3 = new DRFModel.DRFParameters();
        parms3._distribution = AUTO;
        parms3._response_column = "IsDepDelayed";

        DRFModel.DRFParameters parms4 = new DRFModel.DRFParameters();
        parms4._distribution = AUTO;
        parms4._response_column = "IsDepDelayed";
        parms4._balance_classes = true;

        GBMModel.GBMParameters parms5 = new GBMModel.GBMParameters();
        parms5._distribution = DistributionFamily.gaussian;
        parms5._response_column = "Distance";

        GBMModel.GBMParameters parms6 = new GBMModel.GBMParameters();
        parms6._distribution = DistributionFamily.tweedie;
        parms6._response_column = "Distance";

        return Arrays.asList(parms1, parms2, parms3, parms4, parms5, parms6);
    }

    /** Tree-model parameters for the current run; the test fills in {@code _train} and {@code _ntrees}. */
    @Parameterized.Parameter
    public SharedTreeModel.SharedTreeParameters parms;

    /**
     * Trains a tree model, round-trips it through a MOJO file into a {@link GenericModel} and checks
     * that POJO source and scoring results are identical.
     *
     * @throws Exception if training, MOJO export/import or scoring fails
     */
    @Test
    @SuppressWarnings("unchecked")
    public void testConvertTreeMojoToPojo() throws Exception {
        Scope.enter();
        File originalModelMojoFile = null;
        try {
            // Train a GBM or DRF model, depending on the concrete parameter type of this run
            final Frame trainingFrame = parseTestFile("./smalldata/testng/airlines_train.csv");
            Scope.track(trainingFrame);
            final Frame testFrame = parseTestFile("./smalldata/testng/airlines_test.csv");
            Scope.track(testFrame);
            parms._train = trainingFrame._key;
            parms._ntrees = 3;

            SharedTree<?, ?, ?> job = parms instanceof GBMModel.GBMParameters ?
                    new GBM((GBMModel.GBMParameters) parms) : new DRF((DRFModel.DRFParameters) parms);
            final Model<?, ?, ?> model = job.trainModel().get();
            assertNotNull(model);
            Scope.<Model>track_generic(model);

            // Strip from the first "AUTOGENERATED" marker to the end of that line before comparing;
            // NOTE(review): presumably that header text differs between the two generated sources.
            String pojoCode = model.toJava(false, true).replaceFirst("AUTOGENERATED.*", "");

            originalModelMojoFile = File.createTempFile("mojo", "zip");
            // try-with-resources releases the file handle even if writing the MOJO fails
            try (FileOutputStream mojoOut = new FileOutputStream(originalModelMojoFile)) {
                model.getMojo().writeTo(mojoOut);
            }

            Frame scoredOriginal = model.score(testFrame);
            Scope.track(scoredOriginal);

            // Remove the original model; the generic model below is built under its key
            model.delete();

            GenericModelParameters p = new GenericModelParameters();
            p._path = originalModelMojoFile.getAbsolutePath();
            GenericModel generic = (GenericModel) Generic.make(p, (Key) model._key).trainModel().get();
            Scope.track_generic(generic);

            String genericPojoCode = generic.toJava(false, true).replaceFirst("AUTOGENERATED.*", "");
            assertEquals(pojoCode, genericPojoCode);

            Frame scoredGeneric = generic.score(testFrame);
            Scope.track(scoredGeneric);

            // Predictions must match exactly (delta 0)
            assertFrameEquals(scoredOriginal, scoredGeneric, 0);
        } finally {
            Scope.exit();
            // Best-effort cleanup of the exported MOJO; a failed delete must not fail the test
            if (originalModelMojoFile != null) {
                originalModelMojoFile.delete();
            }
        }
    }

}
